import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss

from core.rl.components.agent import BaseAgent
from core.utils.general_utils import ParamDict, AttrDict, map_dict, remove_spatial
from core.utils.pytorch_utils import map2torch, map2np, ten2ar, ar2ten

class BCAgent(BaseAgent):
    """Behavior cloning agent.

    Trains a policy by regressing its predicted actions onto the actions of
    batches sampled from a replay buffer.
    """

    def __init__(self, config):
        """Build the policy, its optimizer and the replay buffer.

        :param config: hyperparameter overrides applied on top of `_default_hparams()`.
        """
        super().__init__(config)
        self._hp = self._default_hparams().overwrite(config)

        self.policy = self._hp.policy(self._hp.policy_params)
        self.policy_opt = self._get_optimizer(self._hp.optimizer, self.policy, self._hp.policy_lr)

        self.replay_buffer = self._hp.replay(self._hp.replay_params)

        # NOTE: initialized here but never modified by this class.
        self._update_steps = 0

    def _default_hparams(self):
        """Return the parent's default hyperparameters extended with the BC-specific ones."""
        default_dict = ParamDict({
            'policy': None,     # policy class
            'policy_params': None,  # parameters for the policy class
            'policy_lr': 3e-4,          # learning rate for policy update
            'replay': None,             # replay buffer class
            'replay_params': None,      # parameters for replay buffer
            'loss_type': 'mse',         # action regression loss: 'mse' or 'l1'
        })
        return super()._default_hparams().overwrite(default_dict)

    def _act(self, obs):
        """Compute the policy output for a single observation.

        :param obs: one unbatched observation; a batch dimension is added before
            the policy call and removed from the output afterwards.
        :return: the policy output converted to numpy, without the 'dist' entry.
        """
        obs = map2torch(self._obs_normalizer(obs), self._hp.device)
        # The output is converted to numpy right away, so no autograd graph is needed.
        with torch.no_grad():
            policy_output = self._remove_batch(self.policy(obs[None]))
        if 'dist' in policy_output:
            del policy_output['dist']
        return map2np(policy_output)

    def update(self, experience_batch):
        """Run `update_iterations` behavior cloning steps on replay buffer samples.

        :param experience_batch: ignored by this agent; every training batch is
            sampled from `self.replay_buffer` instead.
        :return: AttrDict with the `action_loss` of the last iteration, or an
            empty AttrDict if no iteration was run.
        """
        # Bind the result up front so that zero update iterations do not raise
        # an UnboundLocalError at the return statement.
        info = AttrDict()
        for _ in range(int(self._hp.update_iterations)):
            batch = self.replay_buffer.sample(n_samples=self._hp.batch_size)
            batch.observation = self._obs_normalizer(batch.observation)
            batch = map2torch(batch, self._hp.device)

            policy_output = self.policy(batch.observation)
            action_loss = self._compute_action_loss(policy_output.action, batch.action)
            self._perform_update(action_loss, self.policy_opt, self.policy)

            info = AttrDict(
                action_loss=action_loss,
            )
        return info

    def _compute_action_loss(self, predicted_action, target_action):
        """Return the scalar regression loss selected by the `loss_type` hyperparameter.

        :param predicted_action: action tensor produced by the policy.
        :param target_action: action tensor from the sampled batch.
        :raises ValueError: if `loss_type` is neither 'mse' nor 'l1'.
        """
        loss_type = self._hp.loss_type
        if loss_type == 'mse':
            # The 0.5 factor keeps the scale of the loss used before `loss_type` was honored.
            return 0.5 * (predicted_action - target_action).pow(2).mean()
        if loss_type == 'l1':
            return F.l1_loss(predicted_action, target_action)
        raise ValueError(
            "Unsupported loss_type {!r}; expected 'mse' or 'l1'.".format(loss_type))

